{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Practical Deep Learning for Coders, v3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Lesson4_tabular"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tabular models\n",
    "# Tabular(表格)模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastai.tabular import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Tabular data should be in a Pandas `DataFrame`.\n",
    "\n",
    "Tabular数据应存放在Pandas的`DataFrame`中。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Resolve the ADULT_SAMPLE dataset location with untar_data, then read the\n",
    "# adult.csv file found under it into a DataFrame.\n",
    "path = untar_data(URLs.ADULT_SAMPLE)\n",
    "df = pd.read_csv(path/'adult.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Column holding the label; later passed to label_from_df.\n",
    "dep_var = 'salary'\n",
    "\n",
    "# Columns passed to TabularList as categorical features.\n",
    "cat_names = [\n",
    "    'workclass', 'education', 'marital-status',\n",
    "    'occupation', 'relationship', 'race',\n",
    "]\n",
    "\n",
    "# Columns passed to TabularList as continuous features.\n",
    "cont_names = ['age', 'fnlwgt', 'education-num']\n",
    "\n",
    "# Preprocessing steps handed to TabularList.from_df for the main dataset.\n",
    "procs = [FillMissing, Categorify, Normalize]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rows 800-999 of df wrapped as a TabularList (no labels attached); it is\n",
    "# later attached to the main dataset via add_test.\n",
    "# NOTE(review): this is the same row range that split_by_idx uses for the\n",
    "# validation set, so the test set here is illustrative, not held-out data.\n",
    "# No procs are passed here; presumably add_test applies the training procs\n",
    "# to these rows (confirm against the fastai docs).\n",
    "test = TabularList.from_df(df.iloc[800:1000].copy(), path=path, cat_names=cat_names, cont_names=cont_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assemble the data object step by step: build the item list from df, hold\n",
    "# out rows 800-999 for validation, take labels from dep_var, attach the\n",
    "# `test` TabularList as the test set, and finalize with databunch().\n",
    "data = (\n",
    "    TabularList\n",
    "    .from_df(df, path=path, cat_names=cat_names, cont_names=cont_names, procs=procs)\n",
    "    .split_by_idx(list(range(800, 1000)))\n",
    "    .label_from_df(cols=dep_var)\n",
    "    .add_test(test)\n",
    "    .databunch()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table>  <col width='10%'>  <col width='10%'>  <col width='10%'>  <col width='10%'>  <col width='10%'>  <col width='10%'>  <col width='10%'>  <col width='10%'>  <col width='10%'>  <col width='10%'>  <col width='10%'>  <tr>\n",
       "    <th>workclass</th>\n",
       "    <th>education</th>\n",
       "    <th>marital-status</th>\n",
       "    <th>occupation</th>\n",
       "    <th>relationship</th>\n",
       "    <th>race</th>\n",
       "    <th>education-num_na</th>\n",
       "    <th>age</th>\n",
       "    <th>fnlwgt</th>\n",
       "    <th>education-num</th>\n",
       "    <th>target</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th> Private</th>\n",
       "    <th> HS-grad</th>\n",
       "    <th> Never-married</th>\n",
       "    <th> Sales</th>\n",
       "    <th> Not-in-family</th>\n",
       "    <th> White</th>\n",
       "    <th>False</th>\n",
       "    <th>-1.2158</th>\n",
       "    <th>1.1004</th>\n",
       "    <th>-0.4224</th>\n",
       "    <th><50k</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th> ?</th>\n",
       "    <th> HS-grad</th>\n",
       "    <th> Widowed</th>\n",
       "    <th> ?</th>\n",
       "    <th> Not-in-family</th>\n",
       "    <th> White</th>\n",
       "    <th>False</th>\n",
       "    <th>1.8627</th>\n",
       "    <th>0.0976</th>\n",
       "    <th>-0.4224</th>\n",
       "    <th><50k</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th> Self-emp-not-inc</th>\n",
       "    <th> HS-grad</th>\n",
       "    <th> Never-married</th>\n",
       "    <th> Craft-repair</th>\n",
       "    <th> Own-child</th>\n",
       "    <th> Black</th>\n",
       "    <th>False</th>\n",
       "    <th>0.0303</th>\n",
       "    <th>0.2092</th>\n",
       "    <th>-0.4224</th>\n",
       "    <th><50k</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th> Private</th>\n",
       "    <th> HS-grad</th>\n",
       "    <th> Married-civ-spouse</th>\n",
       "    <th> Protective-serv</th>\n",
       "    <th> Husband</th>\n",
       "    <th> White</th>\n",
       "    <th>False</th>\n",
       "    <th>1.5695</th>\n",
       "    <th>-0.5938</th>\n",
       "    <th>-0.4224</th>\n",
       "    <th><50k</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th> Private</th>\n",
       "    <th> HS-grad</th>\n",
       "    <th> Married-civ-spouse</th>\n",
       "    <th> Handlers-cleaners</th>\n",
       "    <th> Husband</th>\n",
       "    <th> White</th>\n",
       "    <th>False</th>\n",
       "    <th>-0.9959</th>\n",
       "    <th>-0.0318</th>\n",
       "    <th>-0.4224</th>\n",
       "    <th><50k</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th> Private</th>\n",
       "    <th> 10th</th>\n",
       "    <th> Married-civ-spouse</th>\n",
       "    <th> Farming-fishing</th>\n",
       "    <th> Wife</th>\n",
       "    <th> White</th>\n",
       "    <th>False</th>\n",
       "    <th>-0.7027</th>\n",
       "    <th>0.6071</th>\n",
       "    <th>-1.5958</th>\n",
       "    <th><50k</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th> Private</th>\n",
       "    <th> HS-grad</th>\n",
       "    <th> Married-civ-spouse</th>\n",
       "    <th> Machine-op-inspct</th>\n",
       "    <th> Husband</th>\n",
       "    <th> White</th>\n",
       "    <th>False</th>\n",
       "    <th>0.1036</th>\n",
       "    <th>-0.0968</th>\n",
       "    <th>-0.4224</th>\n",
       "    <th><50k</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th> Private</th>\n",
       "    <th> Some-college</th>\n",
       "    <th> Married-civ-spouse</th>\n",
       "    <th> Exec-managerial</th>\n",
       "    <th> Own-child</th>\n",
       "    <th> White</th>\n",
       "    <th>False</th>\n",
       "    <th>-0.7760</th>\n",
       "    <th>-0.6653</th>\n",
       "    <th>-0.0312</th>\n",
       "    <th>>=50k</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th> State-gov</th>\n",
       "    <th> Some-college</th>\n",
       "    <th> Never-married</th>\n",
       "    <th> Tech-support</th>\n",
       "    <th> Own-child</th>\n",
       "    <th> White</th>\n",
       "    <th>False</th>\n",
       "    <th>-0.8493</th>\n",
       "    <th>-1.4959</th>\n",
       "    <th>-0.0312</th>\n",
       "    <th><50k</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th> Private</th>\n",
       "    <th> 11th</th>\n",
       "    <th> Never-married</th>\n",
       "    <th> Machine-op-inspct</th>\n",
       "    <th> Not-in-family</th>\n",
       "    <th> White</th>\n",
       "    <th>False</th>\n",
       "    <th>-1.0692</th>\n",
       "    <th>-0.9516</th>\n",
       "    <th>-1.2046</th>\n",
       "    <th><50k</th>\n",
       "  </tr>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
    ],
   "source": [
    "# Preview 10 rows from a batch as a table of feature columns plus target.\n",
    "data.show_batch(rows=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the learner for `data`, passing layers=[200,100] (presumably the\n",
    "# hidden layer sizes) and reporting accuracy as the metric.\n",
    "learn = tabular_learner(data, layers=[200,100], metrics=accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "Total time: 00:03 <p><table style='width:300px; margin-bottom:10px'>\n",
       "  <tr>\n",
       "    <th>epoch</th>\n",
       "    <th>train_loss</th>\n",
       "    <th>valid_loss</th>\n",
       "    <th>accuracy</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>1</th>\n",
       "    <th>0.354604</th>\n",
       "    <th>0.378520</th>\n",
       "    <th>0.820000</th>\n",
       "  </tr>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
    ],
   "source": [
    "# Train for 1 epoch; the second argument (1e-2) is presumably the learning rate.\n",
    "learn.fit(1, 1e-2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference  预测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Take the first row of the raw DataFrame as a single example to predict on.\n",
    "row = df.iloc[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Category >=50k, tensor(1), tensor([0.4402, 0.5598]))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
    ],
   "source": [
    "# Predict for the single row. The displayed result is a 3-tuple: the predicted\n",
    "# category, a tensor with its index, and a tensor with one value per class\n",
    "# (presumably class probabilities).\n",
    "learn.predict(row)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
